#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Class definition for the OFDM Demodulator"""

import numpy as np
import tensorflow as tf
from tensorflow.signal import fftshift

from sionna.phy import Block, PI
from sionna.phy.utils import expand_to_rank
from sionna.phy.signal import fft

class OFDMDemodulator(Block):
    # pylint: disable=line-too-long
    r"""
    Computes the frequency-domain representation of an OFDM waveform
    with cyclic prefix removal

    The demodulator assumes that the input sequence is generated by the
    :class:`~sionna.phy.channel.TimeChannel`. For a single pair of antennas,
    the received signal sequence is given as:

    .. math::

        y_b = \sum_{\ell =L_\text{min}}^{L_\text{max}} \bar{h}_\ell x_{b-\ell} + w_b, \quad b \in[L_\text{min}, N_B+L_\text{max}-1]

    where :math:`\bar{h}_\ell` are the discrete-time channel taps,
    :math:`x_{b}` is the the transmitted signal,
    and :math:`w_\ell` Gaussian noise.

    Starting from the first symbol, the demodulator cuts the input
    sequence into pieces of size ``cyclic_prefix_length + fft_size``,
    and throws away any trailing symbols. For each piece, the cyclic
    prefix is removed and the ``fft_size``-point discrete Fourier
    transform is computed. It is also possible that every OFDM symbol
    has a cyclic prefix of different length.

    Since the input sequence starts at time :math:`L_\text{min}`,
    the FFT-window has a timing offset of :math:`L_\text{min}` symbols,
    which leads to a subcarrier-dependent phase shift of
    :math:`e^{\frac{j2\pi k L_\text{min}}{N}}`, where :math:`k`
    is the subcarrier index, :math:`N` is the FFT size,
    and :math:`L_\text{min} \le 0` is the largest negative time lag of
    the discrete-time channel impulse response. This phase shift
    is removed in this layer, by explicitly multiplying
    each subcarrier by  :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
    This is a very important step to enable channel estimation with
    sparse pilot patterns that needs to interpolate the channel frequency
    response accross subcarriers. It also ensures that the
    channel frequency response `seen` by the time-domain channel
    is close to the :class:`~sionna.phy.channel.OFDMChannel`.

    Parameters
    ----------
    fft_size : `int`
        FFT size (, i.e., the number of subcarriers).

    l_min : `int`
        The largest negative time lag of the discrete-time channel
        impulse response. It should be the same value as that used by the
        `cir_to_time_channel` function.

    cyclic_prefix_length : `int` (default 0) | [num_ofdm_symbols], `int`
        Integer or vector of integers indicating the length of the
        cyclic prefix that is prepended to each OFDM symbol. None of its
        elements can be larger than the FFT size.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    :[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)+n], `tf.complex`
        Tensor containing the time-domain signal along the last dimension.
        `n` is a nonnegative integer.

    Output
    ------
    :[...,num_ofdm_symbols,fft_size], `tf.complex`
        Tensor containing the OFDM resource grid along the last
        two dimension.
    """
    def __init__(self, fft_size, l_min, cyclic_prefix_length=0,
                 precision=None, **kwargs):
        super().__init__(precision=precision, **kwargs)
        self._fft_size = None
        self._l_min = None
        self._cyclic_prefix_length = None
        self.fft_size = fft_size
        self.l_min = l_min
        self.cyclic_prefix_length = cyclic_prefix_length

    @property
    def fft_size(self):
        return self._fft_size

    @fft_size.setter
    def fft_size(self, value):
        assert value>0, "`fft_size` must be positive."
        self._fft_size = int(value)

    @property
    def l_min(self):
        return self._l_min

    @l_min.setter
    def l_min(self, value):
        assert value<=0, "l_min must be nonpositive."
        self._l_min = int(value)

    @property
    def cyclic_prefix_length(self):
        return self._cyclic_prefix_length

    @cyclic_prefix_length.setter
    def cyclic_prefix_length(self, value):
        value = tf.cast(value, tf.int32)
        if not tf.reduce_all(value>=0):
            msg = "`cyclic_prefix_length` must be nonnegative."
            raise ValueError(msg)
        if not 0<= tf.rank(value)<=1:
            msg = "`cyclic_prefix_length` must be of rank 0 or 1"
            raise ValueError(msg)
        self._cyclic_prefix_length = value

    def build(self, input_shape): # pylint: disable=unused-argument
        # Compute phase correction terms to to channel
        tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \
              / tf.cast(self.fft_size, tf.float32) \
              * tf.range(self.fft_size, dtype=tf.float32)
        self._phase_compensation = tf.exp(tf.complex(0., tmp))

        if len(self.cyclic_prefix_length.shape)==0:
            # Compute number of elements that will be truncated
            self._rest = np.mod(input_shape[-1],
                                    self.fft_size + self.cyclic_prefix_length)

            # Compute number of full OFDM symbols to be demodulated
            self._num_ofdm_symbols = np.floor_divide(
                                    input_shape[-1]-self._rest,
                                    self.fft_size + self.cyclic_prefix_length)
        else:
            # Deal with individual cp lengths for OFDM symbols
            # Compute the relevant indices to gather for
            # every OFDM symbol from the time domain input
            num_ofdm_symbols = self.cyclic_prefix_length.shape[0]
            row_lengths = self.cyclic_prefix_length + self.fft_size
            offsets = tf.math.cumsum(tf.concat([[0], row_lengths],
                                               axis=0)[:-1])
            offsets = tf.expand_dims(offsets, 1)
            ind = tf.repeat(tf.range(start=0,
                                     limit=self.fft_size)[tf.newaxis,:],
                            repeats=num_ofdm_symbols, axis=0)
            ind += self.cyclic_prefix_length[:, tf.newaxis]
            ind += offsets
            # [num_ofdm_symbols, fft_size]
            self._ind = ind

    def call(self, inputs):
        """Demodulate OFDM waveform onto a resource grid

        Args:
            inputs (`tf.complex`):
                `[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)]`.

        Returns:
            `tf.complex` : The demodulated inputs of shape
            `[...,num_ofdm_symbols, fft_size]`.
        """
        if len(self.cyclic_prefix_length.shape)==0:
            # Same CP length for all OFDM symbols
            # Cut last samples that do not fit into an OFDM symbol
            inputs = inputs if self._rest==0 else inputs[...,:-self._rest]

            # Reshape input to separate OFDM symbols
            new_shape = tf.concat(
                            [tf.shape(inputs)[:-1],
                            [self._num_ofdm_symbols],
                            [self.fft_size + self.cyclic_prefix_length]], 0)
            x = tf.reshape(inputs, new_shape)

            # Remove cyclic prefix
            x = x[...,self.cyclic_prefix_length:]

        else:
            # Individual CP length for OFDM symbols
            x = tf.gather(inputs, self._ind, axis=-1)

        # Compute FFT
        x = fft(x)

        # Apply phase shift compensation to all subcarriers
        rot = tf.cast(self._phase_compensation, x.dtype)
        rot = expand_to_rank(rot, tf.rank(x), 0)
        x = x * rot

        # Shift DC subcarrier to the middle
        x = fftshift(x, axes=-1)

        return x
